"""
    Callable score functions that combine the unconditional score with a reward model for guidance.
"""

# pylint:disable=missing-class-docstring

# import math
import warnings
from typing import Union

import torch

# from torch.distributions import Normal


class ScoreRewardModel:
    """Score function guided by a differentiable reward model.

    Calling an instance evaluates::

        score_model(t, x)
        + alpha * grad_x r(x, t) * (condition - r(x, t)) / variance

    where ``r`` is ``reward_model``. The guidance term is the gradient with
    respect to ``x`` of the Gaussian log-likelihood
    ``-(condition - r(x, t)) ** 2 / (2 * variance)``, scaled by ``alpha``.

    Args:
        condition: Target reward. A Python scalar, or a tensor holding either
            a single value or one value per batch element.
        reward_model: Module called as ``reward_model(x_batch=..., time=...)``
            that returns one reward per batch element.
        score_model: Module called as ``score_model(time, x_batch)``; expected
            to return a tensor with the same shape as ``x_batch``.
        variance: Likelihood variance; smaller values strengthen guidance.
        alpha: Extra multiplicative guidance scale.
    """

    def __init__(
        self,
        condition: Union[float, torch.Tensor],
        reward_model: torch.nn.Module,
        score_model: torch.nn.Module,
        variance: float,
        alpha: float = 1.0,
    ):
        self.condition = condition
        self.reward_model = reward_model
        self.score_model = score_model
        self.variance = variance
        self.alpha = alpha
        # Tag identifying this as a combined score/reward callable
        # (presumably inspected by the sampler — not read in this file).
        self.type = "combined"

    def __call__(
        self,
        time: torch.Tensor,
        x_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Return the reward-guided score at ``(time, x_batch)``.

        Args:
            time: Diffusion time(s), forwarded unchanged to both models.
            x_batch: Batch of samples; the first dimension indexes the batch.
                Any number of trailing dimensions is supported.

        Returns:
            A detached tensor shaped like ``x_batch``.
        """
        with torch.enable_grad():
            x_grad = x_batch.detach().clone().requires_grad_(True)
            reward = self.reward_model(x_batch=x_grad, time=time)
            # Only the value of the gradient is used, so no higher-order graph
            # (create_graph=True) is built.
            (grad_reward,) = torch.autograd.grad(reward.sum(), x_grad)
            self.reward_model.zero_grad()

        # Unconditional score on a fresh, graph-free copy of the input.
        score_part1 = self.score_model(time, x_batch.detach().clone())

        # as_tensor + reshape(-1) turns a Python scalar into shape (1,) (a
        # 0-dim tensor cannot be indexed per sample) and aligns a tensor
        # condition with the input's dtype/device.
        condition = torch.as_tensor(
            self.condition, dtype=x_batch.dtype, device=x_batch.device
        ).reshape(-1)
        reward = reward.detach().reshape(-1)

        # One factor per sample, shaped (B, 1, ..., 1) so it broadcasts over
        # every non-batch dimension of the gradient.
        scale = (condition - reward) / self.variance
        scale = scale.reshape(-1, *([1] * (grad_reward.dim() - 1)))
        score = score_part1 + self.alpha * grad_reward * scale
        return score.detach().clone()

    def set_alpha(self, alpha: float):
        """Set the guidance scale ``alpha``."""
        self.alpha = alpha

    def set_variance(self, variance: float):
        """Set the likelihood ``variance``."""
        self.variance = variance


class ProteinScoreReward:
    """Reward-guided score that evaluates the reward on a denoised estimate.

    The reward model is applied to::

        x_hat = (x + sigma(t) ** 2 * score_model(t, x)) / m(t)

    with ``sigma(t) = diffusion_process.marginal_prob_std(time=t)`` and
    ``m(t) = diffusion_process.marginal_prob_mean_factor(time=t)``. This is a
    Tweedie-style estimate, assuming ``score_model`` returns the score of the
    noised marginal. The returned score is::

        score_model(t, x)
        + alpha * grad_{x_hat} r(x_hat) * (condition - r(x_hat)) / variance

    Args:
        condition: Target reward. A Python scalar, or a tensor holding either
            a single value or one value per batch element.
        reward_model: Module called as ``reward_model(x_hat)``. Element ``[0]``
            of its output is used. If the module has a non-``None``
            ``bin_centers`` attribute, that output is treated as per-bin
            weights (presumably probabilities) and reduced with
            ``sum(weights * bin_centers, dim=1)``. Otherwise it is used
            directly as one reward per sample.
        score_model: Module called as ``score_model(time, x_batch)``.
        variance: Likelihood variance; smaller values strengthen guidance.
        diffusion_process: Object providing ``marginal_prob_std(time=...)``
            and ``marginal_prob_mean_factor(time=...)``.
        alpha: Extra multiplicative guidance scale.
    """

    def __init__(
        self,
        condition: Union[float, torch.Tensor],
        reward_model: torch.nn.Module,
        score_model: torch.nn.Module,
        variance: float,
        diffusion_process,
        alpha: float = 1.0,
    ):
        self.condition = condition
        self.reward_model = reward_model
        self.score_model = score_model
        self.variance = variance
        self.alpha = alpha
        # Tag identifying this as a combined score/reward callable
        # (presumably inspected by the sampler — not read in this file).
        self.type = "combined"
        self.diffusion_process = diffusion_process

    def __call__(
        self,
        time: torch.Tensor,
        x_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Return the reward-guided score at ``(time, x_batch)``.

        The score model runs in training mode while the denoised estimate is
        built, and in evaluation mode for the returned unconditional score.
        Its original train/eval mode is restored afterwards, even if an
        exception is raised.

        Args:
            time: Diffusion time(s), forwarded to the diffusion process and
                the score model.
            x_batch: Batch of noisy samples; the first dimension indexes the
                batch. Any number of trailing dimensions is supported.

        Returns:
            A detached tensor shaped like ``x_batch``.

        Warns:
            RuntimeWarning: If the reward gradient contains NaN values.
        """
        was_training = self.score_model.training
        try:
            with torch.enable_grad():
                # NOTE(review): train mode here is presumably required so the
                # score model can be differentiated (e.g. cuDNN RNN backward
                # needs it); it also activates dropout — confirm.
                self.score_model.train()
                x_grad = x_batch.detach().clone().requires_grad_(True)
                # NOTE(review): assumes these broadcast against x_batch
                # (scalars or (B, 1, ...) tensors) — confirm.
                var = self.diffusion_process.marginal_prob_std(time=time) ** 2
                mean_factor = self.diffusion_process.marginal_prob_mean_factor(
                    time=time
                )
                x_processed = (
                    x_grad + var * self.score_model(time, x_grad)
                ) / mean_factor

                reward_output = self.reward_model(x_processed)[0]
                bin_centers = getattr(self.reward_model, "bin_centers", None)
                if bin_centers is not None:
                    # Collapse per-bin weights into one reward per sample.
                    reward = torch.sum(reward_output * bin_centers, dim=1)
                else:
                    reward = reward_output

                # NOTE(review): the gradient is taken w.r.t. the denoised
                # estimate x_processed, not x_batch, so it does not flow back
                # through score_model — confirm this is intended.
                (grad_reward,) = torch.autograd.grad(reward.sum(), x_processed)
                self.reward_model.zero_grad()
                self.score_model.zero_grad()
                if torch.isnan(grad_reward).any():
                    warnings.warn("gradient contains nan values", RuntimeWarning)

            # Unconditional score in eval mode on a graph-free copy.
            self.score_model.eval()
            score_part1 = self.score_model(time, x_batch.detach().clone())
        finally:
            self.score_model.train(was_training)

        # as_tensor + reshape(-1) turns a Python scalar into shape (1,) (a
        # 0-dim tensor cannot be indexed per sample) and aligns a tensor
        # condition with the input's dtype/device.
        condition = torch.as_tensor(
            self.condition, dtype=x_batch.dtype, device=x_batch.device
        ).reshape(-1)
        reward = reward.detach().reshape(-1)

        # One factor per sample, shaped (B, 1, ..., 1) so it broadcasts over
        # every non-batch dimension of the gradient.
        scale = (condition - reward) / self.variance
        scale = scale.reshape(-1, *([1] * (grad_reward.dim() - 1)))
        score = score_part1 + self.alpha * grad_reward * scale
        return score.detach().clone()

    def set_alpha(self, alpha: float):
        """Set the guidance scale ``alpha``."""
        self.alpha = alpha

    def set_variance(self, variance: float):
        """Set the likelihood ``variance``."""
        self.variance = variance
